; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=x86_64-unknown-unknown --fp-contract=fast --enable-no-signed-zeros-fp-math -mattr=avx512fp16 | FileCheck %s --check-prefixes=CHECK,NO-SZ
; RUN: llc < %s -mtriple=x86_64-unknown-unknown --fp-contract=fast -mattr=avx512fp16 | FileCheck %s --check-prefixes=CHECK,HAS-SZ

; FADD(acc, FMA(a, b, +0.0)) can be combined to FMA(a, b, acc) if the nsz flag is set.
; test1: 512-bit vfcmadd.cph called with an all-zero (+0.0) addend and an
; all-ones mask (i16 -1); its result is then added to %acc with a plain fadd.
; Expected codegen per the assertions below: under the NO-SZ prefix the fadd is
; folded away and %acc (zmm0) becomes the accumulator of a single vfcmaddcph;
; under the HAS-SZ prefix the zeroed accumulator, the vfcmaddcph and a separate
; vaddph are all kept.
; NOTE(review): the trailing i32 4 is presumably the rounding-control immediate
; (current direction) - confirm against the intrinsic definition.
define dso_local <32 x half> @test1(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
; NO-SZ-LABEL: test1:
; NO-SZ:       # %bb.0: # %entry
; NO-SZ-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm0
; NO-SZ-NEXT:    retq
;
; HAS-SZ-LABEL: test1:
; HAS-SZ:       # %bb.0: # %entry
; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
; HAS-SZ-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm3
; HAS-SZ-NEXT:    vaddph %zmm0, %zmm3, %zmm0
; HAS-SZ-NEXT:    retq
entry:
  ; The intrinsic is typed on <16 x float>, so the half vectors are
  ; reinterpreted (not converted) before the call and the result is
  ; reinterpreted back afterwards.
  %0 = bitcast <32 x half> %a to <16 x float>
  %1 = bitcast <32 x half> %b to <16 x float>
  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> zeroinitializer, i16 -1, i32 4)
  %3 = bitcast <16 x float> %2 to <32 x half>
  %add.i = fadd <32 x half> %3, %acc
  ret <32 x half> %add.i
}

; test2: 512-bit vfmadd.cph called with an all-zero (+0.0) addend and an
; all-ones mask (i16 -1); its result is then added to %acc with a plain fadd.
; Expected codegen per the assertions below: under the NO-SZ prefix a single
; vfmaddcph accumulates directly into %acc (zmm0); under the HAS-SZ prefix the
; zeroed accumulator, the vfmaddcph and a separate vaddph are all kept.
; NOTE(review): the trailing i32 4 is presumably the rounding-control immediate
; (current direction) - confirm against the intrinsic definition.
define dso_local <32 x half> @test2(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
; NO-SZ-LABEL: test2:
; NO-SZ:       # %bb.0: # %entry
; NO-SZ-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm0
; NO-SZ-NEXT:    retq
;
; HAS-SZ-LABEL: test2:
; HAS-SZ:       # %bb.0: # %entry
; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
; HAS-SZ-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm3
; HAS-SZ-NEXT:    vaddph %zmm0, %zmm3, %zmm0
; HAS-SZ-NEXT:    retq
entry:
  ; The intrinsic is typed on <16 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <32 x half> %a to <16 x float>
  %1 = bitcast <32 x half> %b to <16 x float>
  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> zeroinitializer, i16 -1, i32 4)
  %3 = bitcast <16 x float> %2 to <32 x half>
  %add.i = fadd <32 x half> %3, %acc
  ret <32 x half> %add.i
}

; test3: 256-bit vfcmadd.cph called with an all-zero (+0.0) addend and an
; all-ones mask (i8 -1); its result is then added to %acc with a plain fadd.
; Expected codegen per the assertions below: under the NO-SZ prefix a single
; vfcmaddcph accumulates directly into %acc (ymm0); under the HAS-SZ prefix the
; zeroed accumulator, the vfcmaddcph and a separate vaddph are all kept.
define dso_local <16 x half> @test3(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
; NO-SZ-LABEL: test3:
; NO-SZ:       # %bb.0: # %entry
; NO-SZ-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm0
; NO-SZ-NEXT:    retq
;
; HAS-SZ-LABEL: test3:
; HAS-SZ:       # %bb.0: # %entry
; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
; HAS-SZ-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm3
; HAS-SZ-NEXT:    vaddph %ymm0, %ymm3, %ymm0
; HAS-SZ-NEXT:    retq
entry:
  ; The intrinsic is typed on <8 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <16 x half> %a to <8 x float>
  %1 = bitcast <16 x half> %b to <8 x float>
  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> zeroinitializer, i8 -1)
  %3 = bitcast <8 x float> %2 to <16 x half>
  %add.i = fadd <16 x half> %3, %acc
  ret <16 x half> %add.i
}

; test4: 256-bit vfmadd.cph called with an all-zero (+0.0) addend and an
; all-ones mask (i8 -1); its result is then added to %acc with a plain fadd.
; Expected codegen per the assertions below: under the NO-SZ prefix a single
; vfmaddcph accumulates directly into %acc (ymm0); under the HAS-SZ prefix the
; zeroed accumulator, the vfmaddcph and a separate vaddph are all kept.
define dso_local <16 x half> @test4(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
; NO-SZ-LABEL: test4:
; NO-SZ:       # %bb.0: # %entry
; NO-SZ-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm0
; NO-SZ-NEXT:    retq
;
; HAS-SZ-LABEL: test4:
; HAS-SZ:       # %bb.0: # %entry
; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
; HAS-SZ-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm3
; HAS-SZ-NEXT:    vaddph %ymm0, %ymm3, %ymm0
; HAS-SZ-NEXT:    retq
entry:
  ; The intrinsic is typed on <8 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <16 x half> %a to <8 x float>
  %1 = bitcast <16 x half> %b to <8 x float>
  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> zeroinitializer, i8 -1)
  %3 = bitcast <8 x float> %2 to <16 x half>
  %add.i = fadd <16 x half> %3, %acc
  ret <16 x half> %add.i
}

; test5: 128-bit vfcmadd.cph called with an all-zero (+0.0) addend and an
; all-ones mask (i8 -1); its result is then added to %acc with a plain fadd.
; Expected codegen per the assertions below: under the NO-SZ prefix a single
; vfcmaddcph accumulates directly into %acc (xmm0); under the HAS-SZ prefix the
; zeroed accumulator, the vfcmaddcph and a separate vaddph are all kept.
define dso_local <8 x half> @test5(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
; NO-SZ-LABEL: test5:
; NO-SZ:       # %bb.0: # %entry
; NO-SZ-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm0
; NO-SZ-NEXT:    retq
;
; HAS-SZ-LABEL: test5:
; HAS-SZ:       # %bb.0: # %entry
; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
; HAS-SZ-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm3
; HAS-SZ-NEXT:    vaddph %xmm0, %xmm3, %xmm0
; HAS-SZ-NEXT:    retq
entry:
  ; The intrinsic is typed on <4 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <8 x half> %a to <4 x float>
  %1 = bitcast <8 x half> %b to <4 x float>
  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> zeroinitializer, i8 -1)
  %3 = bitcast <4 x float> %2 to <8 x half>
  %add.i = fadd <8 x half> %3, %acc
  ret <8 x half> %add.i
}

; test6: 128-bit vfmadd.cph called with an all-zero (+0.0) addend and an
; all-ones mask (i8 -1); its result is then added to %acc with a plain fadd.
; Expected codegen per the assertions below: under the NO-SZ prefix a single
; vfmaddcph accumulates directly into %acc (xmm0); under the HAS-SZ prefix the
; zeroed accumulator, the vfmaddcph and a separate vaddph are all kept.
define dso_local <8 x half> @test6(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
; NO-SZ-LABEL: test6:
; NO-SZ:       # %bb.0: # %entry
; NO-SZ-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm0
; NO-SZ-NEXT:    retq
;
; HAS-SZ-LABEL: test6:
; HAS-SZ:       # %bb.0: # %entry
; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
; HAS-SZ-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm3
; HAS-SZ-NEXT:    vaddph %xmm0, %xmm3, %xmm0
; HAS-SZ-NEXT:    retq
entry:
  ; The intrinsic is typed on <4 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <8 x half> %a to <4 x float>
  %1 = bitcast <8 x half> %b to <4 x float>
  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> zeroinitializer, i8 -1)
  %3 = bitcast <4 x float> %2 to <8 x half>
  %add.i = fadd <8 x half> %3, %acc
  ret <8 x half> %add.i
}

; FADD(acc, FMA(a, b, -0.0)) can be combined to FMA(a, b, acc) regardless of whether the nsz flag is set.
; test13: 512-bit vfcmadd.cph called with a -0.0 addend and an all-ones mask
; (i16 -1); its result is then added to %acc with a plain fadd.
; The addend constant 0xB790000000000000 is the double-format spelling of
; -2^-134, which as a 32-bit float has the bit pattern 0x80008000, i.e. each
; float lane holds two packed half values that are both -0.0.
; Expected codegen per the assertions below (shared by both RUN lines): a
; single vfcmaddcph accumulating directly into %acc (zmm0), with no separate
; add.
; NOTE(review): the trailing i32 4 is presumably the rounding-control immediate
; (current direction) - confirm against the intrinsic definition.
define dso_local <32 x half> @test13(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
; CHECK-LABEL: test13:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm0
; CHECK-NEXT:    retq
entry:
  ; The intrinsic is typed on <16 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <32 x half> %a to <16 x float>
  %1 = bitcast <32 x half> %b to <16 x float>
  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i16 -1, i32 4)
  %3 = bitcast <16 x float> %2 to <32 x half>
  %add.i = fadd <32 x half> %3, %acc
  ret <32 x half> %add.i
}

; test14: 512-bit vfmadd.cph called with a -0.0 addend and an all-ones mask
; (i16 -1); its result is then added to %acc with a plain fadd.
; The addend constant 0xB790000000000000 is the double-format spelling of
; -2^-134, which as a 32-bit float has the bit pattern 0x80008000, i.e. each
; float lane holds two packed half values that are both -0.0.
; Expected codegen per the assertions below (shared by both RUN lines): a
; single vfmaddcph accumulating directly into %acc (zmm0), with no separate
; add.
; NOTE(review): the trailing i32 4 is presumably the rounding-control immediate
; (current direction) - confirm against the intrinsic definition.
define dso_local <32 x half> @test14(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
; CHECK-LABEL: test14:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm0
; CHECK-NEXT:    retq
entry:
  ; The intrinsic is typed on <16 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <32 x half> %a to <16 x float>
  %1 = bitcast <32 x half> %b to <16 x float>
  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i16 -1, i32 4)
  %3 = bitcast <16 x float> %2 to <32 x half>
  %add.i = fadd <32 x half> %3, %acc
  ret <32 x half> %add.i
}

; test15: 256-bit vfcmadd.cph called with a -0.0 addend and an all-ones mask
; (i8 -1); its result is then added to %acc with a plain fadd.
; The addend constant 0xB790000000000000 is the double-format spelling of
; -2^-134, which as a 32-bit float has the bit pattern 0x80008000, i.e. each
; float lane holds two packed half values that are both -0.0.
; Expected codegen per the assertions below (shared by both RUN lines): a
; single vfcmaddcph accumulating directly into %acc (ymm0), with no separate
; add.
define dso_local <16 x half> @test15(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
; CHECK-LABEL: test15:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm0
; CHECK-NEXT:    retq
entry:
  ; The intrinsic is typed on <8 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <16 x half> %a to <8 x float>
  %1 = bitcast <16 x half> %b to <8 x float>
  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
  %3 = bitcast <8 x float> %2 to <16 x half>
  %add.i = fadd <16 x half> %3, %acc
  ret <16 x half> %add.i
}

; test16: 256-bit vfmadd.cph called with a -0.0 addend and an all-ones mask
; (i8 -1); its result is then added to %acc with a plain fadd.
; The addend constant 0xB790000000000000 is the double-format spelling of
; -2^-134, which as a 32-bit float has the bit pattern 0x80008000, i.e. each
; float lane holds two packed half values that are both -0.0.
; Expected codegen per the assertions below (shared by both RUN lines): a
; single vfmaddcph accumulating directly into %acc (ymm0), with no separate
; add.
define dso_local <16 x half> @test16(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
; CHECK-LABEL: test16:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm0
; CHECK-NEXT:    retq
entry:
  ; The intrinsic is typed on <8 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <16 x half> %a to <8 x float>
  %1 = bitcast <16 x half> %b to <8 x float>
  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
  %3 = bitcast <8 x float> %2 to <16 x half>
  %add.i = fadd <16 x half> %3, %acc
  ret <16 x half> %add.i
}

; test17: 128-bit vfcmadd.cph called with a -0.0 addend and an all-ones mask
; (i8 -1); its result is then added to %acc with a plain fadd.
; The addend constant 0xB790000000000000 is the double-format spelling of
; -2^-134, which as a 32-bit float has the bit pattern 0x80008000, i.e. each
; float lane holds two packed half values that are both -0.0.
; Expected codegen per the assertions below (shared by both RUN lines): a
; single vfcmaddcph accumulating directly into %acc (xmm0), with no separate
; add.
define dso_local <8 x half> @test17(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
; CHECK-LABEL: test17:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm0
; CHECK-NEXT:    retq
entry:
  ; The intrinsic is typed on <4 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <8 x half> %a to <4 x float>
  %1 = bitcast <8 x half> %b to <4 x float>
  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
  %3 = bitcast <4 x float> %2 to <8 x half>
  %add.i = fadd <8 x half> %3, %acc
  ret <8 x half> %add.i
}

; test18: 128-bit vfmadd.cph called with a -0.0 addend and an all-ones mask
; (i8 -1); its result is then added to %acc with a plain fadd.
; The addend constant 0xB790000000000000 is the double-format spelling of
; -2^-134, which as a 32-bit float has the bit pattern 0x80008000, i.e. each
; float lane holds two packed half values that are both -0.0.
; Expected codegen per the assertions below (shared by both RUN lines): a
; single vfmaddcph accumulating directly into %acc (xmm0), with no separate
; add.
define dso_local <8 x half> @test18(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
; CHECK-LABEL: test18:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm0
; CHECK-NEXT:    retq
entry:
  ; The intrinsic is typed on <4 x float>; the bitcasts only reinterpret the
  ; half vectors, they do not convert values.
  %0 = bitcast <8 x half> %a to <4 x float>
  %1 = bitcast <8 x half> %b to <4 x float>
  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
  %3 = bitcast <4 x float> %2 to <8 x half>
  %add.i = fadd <8 x half> %3, %acc
  ret <8 x half> %add.i
}

; Declarations of the AVX512-FP16 vfmadd.cph / vfcmadd.cph intrinsics called by
; the tests in this file, at 512-, 256- and 128-bit widths. Each takes three
; float vectors (used by the tests as a, b and the addend) followed by an
; integer mask operand: i16 for the 512-bit forms, i8 for the 256- and 128-bit
; forms. Only the 512-bit forms take a further i32 immediate (immarg).
; NOTE(review): that i32 immediate is presumably the rounding-control operand -
; confirm against the intrinsic definition.
declare <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
declare <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
declare <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
declare <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
declare <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
declare <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
